{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2a9a4db9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.distributions import Normal\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "962bf3f3",
   "metadata": {},
   "source": [
    "### Bayesian Inferencing of both Mean and Precision of Gaussian likelihood\n",
    "\n",
    "We will now study the case where both the mean and the variance is unknown.\n",
    "\n",
    "We will start, as usual, with the likelihood term - expressing it in terms of the precision, $\\lambda$ (which is related to variance $\\sigma$ as $\\lambda = \\frac{ 1 } { \\sigma^{2} }$).\n",
    "The likelihood term is a Gaussian with mean $\\mu$ precision $\\lambda$\n",
    "\n",
    "$$p\\left( X \\middle\\vert \\lambda\\right) \\propto \\lambda^{ \\frac{ n }{ 2 } }  e^{ -\\frac{ \\lambda } { 2 } \\sum_{i=1}^{n} \\left( {x^{ \\left( i \\right) } - \\mu } \\right)^2 }$$\n",
    "\n",
    "We will make the prior for the mean a Gaussian with mean $\\mu_{0}$ and precision $\\lambda_{0} \\lambda$\n",
    "\n",
    "$$p\\left( \\mu \\middle\\vert \\lambda \\right) \\propto   \\lambda^{ \\frac{1}{2} }  e^{ -\\frac{ \\lambda_{0} \\lambda } { 2 }  \\left(  \\mu - \\mu_{0}  \\right)^{2} }$$\n",
    "\n",
    "We will make the prior for the precision a Gamma distribution \n",
    "$$p\\left( \\lambda \\right)  \\propto  \\lambda ^{ \\left( \\alpha_{0} - 1 \\right)} e^{ - \\beta_{0} \\lambda  }$$\n",
    "\n",
    "The overall prior function is a Normal-Gamma distribution\n",
    "\n",
    "$$p\\left( \\mu, \\lambda \\right)  \\propto  \n",
    "\\lambda^{ \\frac{1}{2} }  e^{ -\\frac{ \\lambda_{0} \\lambda } { 2 }  \\left(  \\mu - \\mu_{0}  \\right)^{2} } \n",
    "\\;\\;\\; \n",
    "\\lambda ^{ \\left( \\alpha_{0} - 1 \\right)} e^{ - \\beta_{0} \\lambda  }$$\n",
    "\n",
    "The corresponding posterior is also a Normal-Gamma distribution, such that \n",
    "$$p\\left( \\mu, \\lambda \\middle\\vert X \\right) \\propto\n",
    "e^{ -\\frac{\\lambda}{2} \\lambda_{n}  \\left( \\mu - \\mu_{n} \\right)^{2} } \\lambda^{ \\alpha_{n} - \\frac{1}{2} } e^{ -\\beta_{n} \\lambda }$$\n",
    "\n",
    "where \n",
    "$$\\mu_{n} = \\frac{ \\left( n \\bar{x} + \\mu_{0} \\lambda_{0} \\right) }{ n + \\lambda_{0} } \n",
    "\\lambda_{n} = n + \\lambda_{0} \\\\\n",
    "\\alpha_{n} = \\frac{n}{2} + \\alpha_{0} \n",
    "\\beta_{n} = \\frac{ ns }{ 2 } + \\beta_{ 0 } + \\frac{ n \\lambda_{0} } { 2 \\left( n + \\lambda_{0} \\right) } \\left( \\bar{x} - \\mu_{0} \\right)^{ 2 }$$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "48ddb0f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Torch does not support NormalGamma distributions yet \n",
    "# So let us implement a bare bones implementation of the Normal Gamma distribution\n",
    "\n",
    "class NormalGamma():\n",
    "    def __init__(self, mu_, lambda_, alpha_, beta_):\n",
    "        self.mu_ = mu_\n",
    "        self.lambda_ = lambda_\n",
    "        self.alpha_ = alpha_\n",
    "        self.beta_ = beta_\n",
    "        \n",
    "    @property\n",
    "    def mean(self):\n",
    "        return self.mu_, self.alpha_/ self.beta_\n",
    "\n",
    "    \n",
    "    @property\n",
    "    def mode(self):\n",
    "        return self.mu_, (self.alpha_-0.5)/ self.beta_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a6bfcc8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference_unknown_mean_variance(X, prior_dist):\n",
    "    mu_mle = X.mean()\n",
    "    sigma_mle = X.std()\n",
    "    n = X.shape[0]\n",
    "    # Parameters of the prior\n",
    "    mu_0 = prior_dist.mu_\n",
    "    lambda_0 = prior_dist.lambda_\n",
    "    alpha_0 = prior_dist.alpha_\n",
    "    beta_0 = prior_dist.beta_\n",
    "    \n",
    "    # Parameters of posterior\n",
    "    mu_n = (n * mu_mle + mu_0 * lambda_0) / (lambda_0 + n) \n",
    "    lambda_n = n + lambda_0\n",
    "    alpha_n = n / 2 + alpha_0\n",
    "    beta_n = n / 2 * sigma_mle ** 2 + beta_0 + 0.5* n * lambda_0/(n + lambda_0) * (mu_mle - mu_0) **2 \n",
    "    posterior_dist = NormalGamma(mu_n, lambda_n, alpha_n, beta_n)\n",
    "    \n",
    "    return posterior_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4571fce1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let us assume that the true distribution is a normal distribution. The true distribution corresponds \n",
    "# to a single class.\n",
    "\n",
    "true_dist = Normal(20, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6e2b4c50",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True distribution: mu 20.00 std 5.00\n",
      "MLE: mu 21.17 std 0.52\n",
      "MAP: mu 19.50 std 4.79\n"
     ]
    }
   ],
   "source": [
    "# Case 1: Low data\n",
    "# Let us assume our prior is a Gamma distribution with a good estimate of the variance\n",
    "prior_dist = NormalGamma(19, 10, 1, 40)\n",
    "\n",
    "# Let us set a seed for reproducability\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# Number of samples is low. \n",
    "n = 3\n",
    "X = true_dist.sample((n, 1))\n",
    "posterior_dist_low_n = inference_unknown_mean_variance(X, prior_dist)\n",
    "\n",
    "true_distribution_mu, true_distribution_std = true_dist.mean, true_dist.scale\n",
    "mle_mu, mle_std = X.mean(), X.std()\n",
    "map_mu, map_precision =  posterior_dist_low_n.mode\n",
    "map_std = math.sqrt(1 / map_precision)\n",
    "\n",
    "# When n is low, the posterior is dominated by the prior. Thus, a good prior can help offset the lack of data.\n",
    "# We can see this in the following case. \n",
    "\n",
    "# With a small sample (n=3), the MLE estimate of the standard deviation is 0.52, which is way off from the true value of 5.0\n",
    "# Using a good prior here helps offset it.\n",
    "\n",
    "print(f\"True distribution: mu {true_distribution_mu:0.2f} std {true_distribution_std:0.2f}\")\n",
    "print(f\"MLE: mu {mle_mu:0.2f} std {mle_std:0.2f}\")\n",
    "print(f\"MAP: mu {map_mu:0.2f} std {map_std:0.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d2a0cc3d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True distribution: mu 20.00 std 5.00\n",
      "MLE: mu 20.02 std 5.02\n",
      "MAP: mu 20.01 std 5.02\n"
     ]
    }
   ],
   "source": [
    "# Case 2: High data\n",
    "# Let us assume our prior is a Gamma distribution with a good estimate of the variance\n",
    "prior_dist = NormalGamma(19, 10, 1, 40)\n",
    "\n",
    "# Let us set a seed for reproducability\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# Number of samples is low. \n",
    "n = 1000\n",
    "X = true_dist.sample((n, 1))\n",
    "posterior_dist_high_n = inference_unknown_mean_variance(X, prior_dist)\n",
    "\n",
    "true_distribution_mu, true_distribution_std = true_dist.mean, true_dist.scale\n",
    "mle_mu, mle_std = X.mean(), X.std()\n",
    "map_mu, map_precision =  posterior_dist_high_n.mode\n",
    "map_std = math.sqrt(1 / map_precision)\n",
    "\n",
    "# When n is high, the MLE converges to the true distribution. The MAP also converges to the MLE, and in turn \n",
    "# converges to the true distribution\n",
    "\n",
    "print(f\"True distribution: mu {true_distribution_mu:0.2f} std {true_distribution_std:0.2f}\")\n",
    "print(f\"MLE: mu {mle_mu:0.2f} std {mle_std:0.2f}\")\n",
    "print(f\"MAP: mu {map_mu:0.2f} std {map_std:0.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da3de1cb",
   "metadata": {},
   "source": [
    "### How to use the  estimated mean and variance parameters?\n",
    "\n",
    "As usual, we will obtain the maxima of the posterior probability density function $p\\left( \\mu, \\sigma \\middle\\vert X \\right) = Normal-Gamma\\left(  \\mu, \\sigma ; \\;\\; \\mu_{n}, \\lambda_{n}, \\alpha_{n}, \\beta_{n} \\right) $.\n",
    "\n",
    "This function attains its maxima when\n",
    "\n",
    "$$\\mu = \\mu_{n} \\\\\n",
    "\\lambda = \\frac{ \\alpha_{n} - \\frac{1}{2} } { \\beta_{n} }$$\n",
    "\n",
    "Thus, the probability density function for data instance $x$ belonging to the same class as the training data $X$ is $\\mathcal{N} \\left( x; \\mu_{n} ,  \\sigma_{n}  \\right)$ where $\\frac{1}{ \\sigma_{n}^{2} } = \\frac{ \\alpha_{n} - \\frac{1}{2} } { \\beta_{n} }$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dcb1e1ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MAP distribution mu: 20.01 std:5.02\n"
     ]
    }
   ],
   "source": [
    "map_mu, map_precision =  posterior_dist_high_n.mode\n",
    "map_std = math.sqrt(1 / map_precision)\n",
    "map_dist = Normal(map_mu, map_std)\n",
    "print(f\"MAP distribution mu: {map_dist.mean:0.2f} std:{map_dist.scale:0.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
